require(tidyverse)
require(data.table)
require(igraph)
require(tidygraph)
require(Matrix)
library(glue)
require(bigalgebra)
## HELPERS -----
`%notin%` <- function(x, y) {
  # Negated %in%: TRUE for each element of x that does not occur in y.
  # With nomatch = 0L, match() returns 0 exactly for the non-matches.
  match(x, y, nomatch = 0L) == 0L
}

preprocess_text <- function(x) {
  # Normalizes free text: every run of whitespace (regex '\s+') is collapsed
  # to a single space, then the result is lower-cased.
  collapsed <- str_replace_all(x, '\\s+', ' ')
  tolower(collapsed)
}

source_lines <- function(file, lines) {
  # Runs only the selected lines of an R file.
  #
  # file  - path to the file to read.
  # lines - index vector selecting which lines of the file to run.
  #
  # The selected lines are passed to source() through a text connection, so
  # they are evaluated the same way source() would (global env by default).
  all_lines <- readLines(file)
  selected <- all_lines[lines]
  source(textConnection(selected))
}

make_transition_matrix <- function(g,
                                   exog,
                                   exog.score.name = '{Insert Exog Score Name}') {
  # Builds a row-stochastic (sparse) transition matrix 'M' from a tidygraph
  # of orgs and briefs (brief nodes are those whose name matches 'B-\d+').
  #
  # g               - tidygraph with a 'weight' edge attribute (read via
  #                   as_adjacency_matrix(attr = 'weight')).
  # exog            - character vector of exogenous org names.
  # exog.score.name - label used only in the progress message.
  #
  # Returns M with rows/columns ordered: exogenous orgs first, then the
  # endogenous nodes that share a connected component with at least one
  # exogenous org. Every non-exogenous row holds the node's edge weights
  # divided by its weighted degree. Exogenous rows are absorbing (1 on the
  # diagonal, 0 elsewhere). Nodes with zero weighted degree are dropped.
  #
  # TO DO: support positioning briefs closer to the more central orgs that
  # signed them (centrality-weighted nodes).
  node.names <- g %>% activate(nodes) %>% pull(name)
  is.brief.node <- str_detect(node.names, pattern = 'B-\\d+')
  clu <- igraph::components(g)
  membership.names <- names(clu$membership)
  # Only exogenous orgs that are actually nodes of g can seed any component.
  exog.on.graph <- unique(exog[exog %in% membership.names])

  # Summary statistics
  print(glue("Building transition matrix with {exog.score.name} as exogenous scores"))
  print(glue("Input graph has {val} nodes", val = length(node.names)))
  print(glue("{val} of these nodes are orgs", val = sum(!is.brief.node)))
  print(glue("{val} of these nodes are briefs", val = sum(is.brief.node)))
  print(glue("There are {val} orgs in exogenous source orgs on the graph",
             val = length(exog.on.graph)))
  print(glue("***Unfortunately not all endogenous nodes are connected to exogenous nodes***"))
  print(glue("The graph has {val} connected components", val = clu$no))

  # A component is scorable when it contains at least one exogenous org.
  # Indexing by exog.on.graph (not exog) keeps NA out of this set.
  scorable.components <- unique(clu$membership[exog.on.graph])
  print(glue("There are {val} components that are scoreable (i.e. contain at least one exogenous org)",
             val = length(scorable.components)))
  is.scorable.node <- clu$membership %in% scorable.components
  print(glue("The scorable components contain {val} nodes",
             val = sum(is.scorable.node)))

  # Row-normalize the weighted adjacency matrix, dropping zero-degree nodes.
  g_as_mat <- as_adjacency_matrix(g, attr = 'weight')
  N <- rowSums(g_as_mat)
  keep <- N > 0
  # Dividing by a length-nrow vector recycles down columns, i.e. row i / N[i].
  W <- Matrix(data = g_as_mat[keep, keep] / N[keep], sparse = TRUE)

  # Exogenous rows become absorbing states: 1 on the diagonal, 0 elsewhere.
  exog.index <- which(colnames(W) %in% exog)
  scorable.nodes <- membership.names[is.scorable.node]
  endog.index <- which((colnames(W) %notin% exog) & (colnames(W) %in% scorable.nodes))
  W[exog.index, ] <- 0
  diag(W[exog.index, exog.index]) <- 1

  # Reorder so exogenous nodes come first.
  ordering <- c(exog.index, endog.index)
  W[ordering, ordering]
}



estimate_markovian_steady_state <- function(M, log = "estimation_process.txt",
                                            precision.exp = -2,
                                            max.iter = 30) {
  # Approximates the steady state of transition matrix M by repeated squaring
  # on a big.matrix (bigalgebra). After i squarings the result is M^(2^i);
  # e.g. M^(2^3) = ((M x M) x (M x M)) x ((M x M) x (M x M)) needs only three
  # multiplications, so 10 squarings already give M^1024.
  #
  # Inversion with solve() is avoided: iterative methods are generally
  # preferred for stability at this data size.
  #
  # M             - square transition matrix. Rows with 1 on the diagonal are
  #                 counted as exogenous orgs.
  # log           - currently unused (the sink() calls are disabled). It is
  #                 kept for backward compatibility.
  # precision.exp - stop once the largest absolute change between successive
  #                 squarings is below (number of exogenous orgs)^precision.exp.
  # max.iter      - maximum number of squarings (default 30, as before).
  #
  # Returns the last squared big.matrix. Warns if the tolerance was not
  # reached within max.iter squarings.
  number.of.exog.orgs <- sum(diag(M) == 1)
  precision <- number.of.exog.orgs^precision.exp

  STO <- as.big.matrix(as.matrix(M))
  start.time <- proc.time()
  maxdiff <- NA_real_
  converged <- FALSE
  for (i in seq_len(max.iter)) {
    LASTITER <- STO
    print("Saved old matrix")
    print(timetaken(start.time))
    STO <- STO %*% STO
    print("Squared old matrix")
    print(timetaken(start.time))
    maxdiff <- max(abs(as.matrix(STO - LASTITER)))
    print("Maximim difference calculated")
    print(maxdiff)
    print(paste0("Iter ", i, " Finished"))
    print(timetaken(start.time))
    # stopping condition
    if (maxdiff < precision) {
      converged <- TRUE
      break
    }
  }
  if (!converged) {
    warning(paste0("estimate_markovian_steady_state did not converge within ",
                   max.iter, " squarings (last max difference ", maxdiff,
                   ", tolerance ", precision, ")"),
            call. = FALSE)
  }
  STO
}


make_orgs_only_graph <- function() {
  # Returns an undirected tidygraph with orgs only (no brief nodes).
  #
  # Edges: data/acnet_Edges_Full.csv minus its first two columns, with
  # duplicate rows removed. Each (org.min, org.max) pair becomes one edge
  # whose 'weight' is the number of distinct rows for that pair.
  # NOTE(review): this presumably equals the number of distinct briefs the
  # two orgs co-signed; confirm against the CSV schema.
  #
  # Vertices: data/acnet_Vertices.csv sorted by orgID, with a preprocessed
  # 'name' column. orgID (the first column) keys the edges.
  # NOTE(review): confirm that the 'name' column replaces orgID as the igraph
  # vertex name. Callers match weights on org names.
  pair_counts <- fread("data/acnet_Edges_Full.csv", drop = 1:2) %>%
    as_tibble() %>%
    distinct() %>%
    rename(briefID = briefs) %>%
    group_by(org1 = org.min, org2 = org.max) %>%
    summarize(weight = n(), .groups = "drop")

  org_table <- read.csv("data/acnet_Vertices.csv") %>%
    rename(orgID = orgId) %>%
    transmute(orgID, name = preprocess_text(properName)) %>%
    as_tibble()
  org_table <- org_table[order(org_table$orgID), ]

  graph_from_data_frame(pair_counts,
                        directed = FALSE,
                        vertices = org_table) %>%
    as_tbl_graph()
}

make_orgs_and_briefs_graph_old <- function() {
  # Legacy builder for the bipartite org-brief graph, using the ACNet web query.
  #
  # Sources code/acnet.R and calls acnet.web.query() for 1910-2021. It parses
  # the first element of the JSON result and splits each row's
  # comma-separated orgIDs into one row per org. Rows are then joined to
  # data/acnet_Vertices_v7.csv (org 10490 excluded) to obtain a preprocessed
  # org name.
  # NOTE(review): fromJSON is not loaded at the top of this file; presumably
  # code/acnet.R attaches it.
  #
  # Returns an undirected tidygraph with one edge (weight = 1) per
  # org-name/briefID row.
  # Stats noted for this version: 15,376 organizations, 14387 briefs and
  # 40,089 edges.
  source("code/acnet.R")
  query_result <- acnet.web.query(minRange = 1910, maxRange = 2021)

  signatures <- fromJSON(query_result)[[1]] %>%
    mutate(orgID = strsplit(as.character(orgIDs), ",", fixed = TRUE)) %>%
    unnest(orgID) %>%
    mutate(orgID = as.numeric(orgID))

  known_orgs <- read.csv("data/acnet_Vertices_v7.csv") %>%
    rename(orgID = orgId) %>%
    filter(orgID != 10490) %>%
    transmute(orgID, name = preprocess_text(properName)) %>%
    as_tibble()

  signatures %>%
    inner_join(known_orgs, by = "orgID") %>%
    transmute(name, briefID, weight = 1) %>%
    graph_from_data_frame(directed = FALSE) %>%
    as_tbl_graph()
}


make_orgs_and_briefs_graph <- function() {
  # Returns an undirected bipartite tidygraph linking org names to brief IDs.
  #
  # Every (briefs, org.min, org.max) row of data/acnet_Edges_Full.csv
  # contributes two signatures: (brief, org.min) and (brief, org.max).
  # Duplicate signatures are removed. Orgs are then named via
  # data/acnet_Vertices.csv (inner join on orgID, so unknown orgs are
  # dropped). Each edge carries 'orgID' and 'weight' (= 1) attributes.
  brief_pairs <- fread("data/acnet_Edges_Full.csv") %>%
    transmute(briefs, org1 = org.min, org2 = org.max) %>%
    as_tibble() %>%
    distinct()

  # All first signers are stacked before all second signers. This row order
  # fixes the order in which vertices are created.
  first_signers <- data.frame(briefID = brief_pairs$briefs,
                              orgID = brief_pairs$org1)
  second_signers <- data.frame(briefID = brief_pairs$briefs,
                               orgID = brief_pairs$org2)
  signatures <- distinct(bind_rows(first_signers, second_signers))

  org_names <- read.csv("data/acnet_Vertices.csv") %>%
    rename(orgID = orgId) %>%
    transmute(orgID, name = preprocess_text(properName)) %>%
    as_tibble()

  signatures %>%
    inner_join(org_names, by = "orgID") %>%
    transmute(name, briefID, orgID, weight = 1) %>%
    graph_from_data_frame(directed = FALSE) %>%
    as_tbl_graph()
}


get_centrality_weights_from_org_graph <- function(centrality.func = centrality_eigen) {
  # Computes a centrality weight for every org in the orgs-only graph.
  #
  # centrality.func - a tidygraph centrality function. It is called with no
  #                   arguments inside mutate() on the node table
  #                   (e.g. centrality_eigen).
  #
  # Returns a numeric vector NAMED by node name. make_weighted_transition_matrix()
  # aligns weights with its graph via w[rownames(A)], so an unnamed vector
  # would silently give every org weight 0. All weights are shifted by
  # .Machine$double.eps to avoid division-by-zero issues.
  g0 <- make_orgs_only_graph()
  org_weights <- g0 %>%
    activate(nodes) %>%
    mutate(centrality.weight = centrality.func()) %>%
    as_tibble()

  w <- org_weights$centrality.weight + .Machine$double.eps
  names(w) <- org_weights$name
  w
}


make_weighted_transition_matrix <- function(g, exog, w = org_weights) {
  # Builds a transition matrix in which brief positions are inferred from the
  # orgs that signed them, weighted by org weight w. Org positions are
  # inferred from the briefs they signed, weighted by the org's share of
  # each brief.
  #
  # g    - an undirected tidygraph of orgs and briefs.
  # exog - character vector of exogenous org names.
  # w    - a NAMED numeric vector of org weights; names must be node names of
  #        g. Nodes of g with no entry in w (e.g. briefs) get weight 0. The
  #        default reads a variable 'org_weights' visible where this function
  #        was defined.
  #
  # See the example_centrality.R script for a toy illustration.
  #
  # Returns M with exogenous orgs first (absorbing rows: 1 on the diagonal,
  # 0 elsewhere), followed by all other nodes.
  if (is.null(names(w))) {
    stop("w must be a named vector of org weights (names = node names of g)",
         call. = FALSE)
  }
  A <- as_adjacency_matrix(g)

  # Align weights with the matrix rows. Unmatched nodes (briefs) get 0.
  w_expanded <- w[rownames(A)]
  names(w_expanded) <- rownames(A)
  w_expanded[is.na(w_expanded)] <- 0

  # Row j of t(A * w) holds the weights of the nodes adjacent to node j, so
  # briefs take their position from their signers' weights. double.xmin
  # keeps all-zero rows at 0 instead of NaN.
  org_part_unnormalized <- drop0(t(A * w_expanded))
  org_part <- org_part_unnormalized /
    (rowSums(org_part_unnormalized) + .Machine$double.xmin)

  # Orgs take their position from the briefs they signed.
  brief_part <- t(org_part) / (rowSums(t(org_part)) + .Machine$double.xmin)

  # Nodes with no weighted connection end up with all-zero rows and are kept
  # as-is. Example: 'us brewers association' signed a single brief
  # (1919-042-B-001), both have weight 0, and so both of their rows in W are
  # all zero.
  W <- org_part + brief_part

  # Exogenous orgs stay exogenous: absorbing rows.
  exog.index <- which(colnames(W) %in% exog)
  endog.index <- which(colnames(W) %notin% exog)
  W[exog.index, ] <- 0
  diag(W[exog.index, exog.index]) <- 1

  # Reorder so exogenous nodes come first.
  ordering <- c(exog.index, endog.index)
  W[ordering, ordering]
}





ideal_point_estimates <- function(M.inf,
                                  exog,
                                  score_vector) {
  # Projects the exogenous ideal points through a steady-state matrix.
  #
  # M.inf        - steady-state transition matrix with row names.
  # exog         - names of the exogenous orgs, aligned with score_vector.
  # score_vector - exogenous scores, in the same order as exog.
  #
  # Returns a numeric vector named like the rows of M.inf, equal to
  # M.inf %*% p. Here p holds each row's exogenous score, or 0 for rows that
  # are not in exog.
  row_ids <- rownames(M.inf)

  keyed <- score_vector
  names(keyed) <- exog

  p <- keyed[row_ids]
  names(p) <- row_ids
  p[is.na(p)] <- 0

  product <- M.inf %*% p
  product[, 1]
}

estimate_ideology_without_steady_state <- function(M, exog, score_vector,
                                                   precision.exp = -2,
                                                   max.iter = 1e5) {
  # Computes ideology scores without forming the steady-state matrix. It
  # repeatedly applies M to the score vector, M %*% (M %*% (... (M %*% p))).
  # This needs far less memory than squaring M.
  #
  # M             - square transition matrix with row names. Rows with 1 on
  #                 the diagonal are counted as exogenous orgs.
  # exog          - names of the exogenous orgs, aligned with score_vector.
  # score_vector  - exogenous scores, in the same order as exog.
  # precision.exp - stop once the largest absolute change between successive
  #                 score vectors is below (number of exogenous orgs)^precision.exp.
  # max.iter      - maximum number of multiplications (default 1e5, as before).
  #
  # Returns the final one-column matrix (rows named as M). Warns if the
  # tolerance was not reached within max.iter multiplications.

  # Align scores with the rows of M. Non-exogenous rows start at 0.
  names(score_vector) <- exog
  score_vector_full <- score_vector[rownames(M)]
  names(score_vector_full) <- rownames(M)
  score_vector_full[is.na(score_vector_full)] <- 0

  number.of.exog.orgs <- sum(diag(M) == 1)
  precision <- number.of.exog.orgs^precision.exp

  STO <- score_vector_full
  maxdiff <- NA_real_
  converged <- FALSE
  for (i in seq_len(max.iter)) {
    LASTITER <- STO
    STO <- M %*% STO
    maxdiff <- max(abs(STO[, 1] - LASTITER))
    if (maxdiff < precision) {
      converged <- TRUE
      break
    }
  }
  if (!converged) {
    warning(paste0("estimate_ideology_without_steady_state did not converge within ",
                   max.iter, " iterations (last max difference ", maxdiff,
                   ", tolerance ", precision, ")"),
            call. = FALSE)
  }
  STO
}


bootstrapped_ideal_point_estimates <- function(M.inf,
                                               exog,
                                               score_vector,
                                               se_vector,
                                               number_of_iterations_to_bootsrap = 100,
                                               seed = 1) {
  # Returns bootstrapped amicus scores given uncertainty in the initial
  # exogenous scores.
  #
  # M.inf        - the converged steady-state transition matrix (row names =
  #                node names).
  # exog         - names of the exogenous orgs.
  # score_vector - exogenous scores, aligned with exog.
  # se_vector    - standard errors of those scores, aligned with exog. Each
  #                exogenous score is noised with rnorm() * its standard error.
  # number_of_iterations_to_bootsrap - number of bootstrap draws.
  # seed         - passed to set.seed(); note this resets the global RNG state.
  #
  # Returns M.inf %*% (noised initial scores): a matrix with one row per row
  # of M.inf (same row names) and one column per draw. Rows that are not
  # exogenous start each draw at 0.
  set.seed(seed)
  n_nodes <- nrow(M.inf)
  node_names <- rownames(M.inf)

  # Standard normal noise, one column per bootstrap draw.
  raw_noise <- matrix(rnorm(n_nodes * number_of_iterations_to_bootsrap),
                      nrow = n_nodes,
                      ncol = number_of_iterations_to_bootsrap)
  rownames(raw_noise) <- node_names

  # Scale each row by its standard error (NA for non-exogenous rows).
  names(se_vector) <- exog
  se_vector_full <- se_vector[node_names]
  names(se_vector_full) <- node_names
  noise <- se_vector_full * raw_noise

  # Noised initial scores. Non-exogenous rows are NA here and are set to 0
  # so the matrix product works.
  names(score_vector) <- exog
  score_vector_full <- score_vector[node_names]
  names(score_vector_full) <- node_names
  bootstrapped.initial.scores <- score_vector_full + noise
  bootstrapped.initial.scores[is.na(bootstrapped.initial.scores)] <- 0

  M.inf %*% bootstrapped.initial.scores
}


loo_crossval <- function(graph, exog, score_vector, weights = FALSE) {
  # Leave-one-out cross-validation of the exogenous scores.
  #
  # For each exogenous org i, the transition matrix is rebuilt from `graph`
  # with org i treated as endogenous. Scores are then propagated from the
  # remaining exogenous orgs with estimate_ideology_without_steady_state(),
  # and the estimate obtained for org i is recorded.
  #
  # graph        - tidygraph of orgs and briefs.
  # exog         - names of the exogenous orgs.
  # score_vector - their scores, aligned with exog.
  # weights      - FALSE: make_transition_matrix(). TRUE:
  #                make_weighted_transition_matrix() with eigenvector
  #                centrality weights from the orgs-only graph.
  #
  # Returns a data.frame with columns crossval_ests (NA when org i is absent
  # from the rebuilt matrix, e.g. not scorable without its own score) and
  # truth (= score_vector).
  if (!is.logical(weights) || length(weights) != 1 || is.na(weights)) {
    stop("weights must be true or false")
  }
  if (length(exog) == 0) {
    return(data.frame(crossval_ests = numeric(0), truth = score_vector))
  }

  # The centrality weights do not depend on the held-out org: compute once.
  if (weights) {
    org_weights <- get_centrality_weights_from_org_graph(centrality_eigen)
  }

  crossval_out <- vector("list", length(exog))
  for (i in seq_along(exog)) {
    held_in <- exog[-i]
    # Rebuild the transition matrix assuming org i is endogenous.
    M <- if (weights) {
      make_weighted_transition_matrix(graph, exog = held_in, org_weights)
    } else {
      make_transition_matrix(graph, held_in)
    }
    p <- estimate_ideology_without_steady_state(M,
                                                exog = held_in,
                                                score_vector = score_vector[-i])

    p_canonical_order <- p[, 1]
    p2 <- data.frame(orgname = names(p_canonical_order),
                     score = p_canonical_order)
    colnames(p2)[2] <- paste0("score_", i)
    crossval_out[[i]] <- p2
    print(i)
  }

  # One row per org, one score_i column per held-out iteration.
  crossval_scores <- Reduce(function(...) merge(..., all = TRUE), crossval_out)
  rownames(crossval_scores) <- crossval_scores$orgname
  crossval_scores <- crossval_scores %>% select(-orgname)

  # For iteration x, pick the estimate of the org that was held out.
  crossval_vec <- vapply(seq_along(exog), function(x) {
    row <- match(exog[x], rownames(crossval_scores))
    if (is.na(row)) NA_real_ else crossval_scores[row, x]
  }, numeric(1))

  data.frame(crossval_ests = crossval_vec, truth = score_vector)
}



# 
# idx = 1:nrow(bonica_ses)
# M = make_transition_matrix(g,bonica_ses$amicus)
# canonical_row_order = rownames(M)
# crossval_out = list()
# for(i in 1:nrow(bonica_ses)){
#   #for(i in 1:20){ ## for testing
#   M = make_transition_matrix(g,bonica_ses$amicus[-c(i)])
#   p = estimate_ideology_without_steady_state(M, 
#                                              exog = bonica_ses$amicus[-c(i)],
#                                              score_vector = bonica_ses$score[-c(i)],
#                                              log = paste0("crossval_logs/crossval_", i, ".txt"))
#   p_canonical_order = p[canonical_row_order,1]
#   crossval_out[[i]] = p_canonical_order
#   print(i)
# }
# 
# crossval_scores = as.data.frame(do.call(cbind, crossval_out))
# 
# save(crossval_scores, file="crossval_results.RData")
# 
# # Next, extract the crossval-estimated score for each iteration
# crossval_vec_list = sapply(1:nrow(bonica_ses), FUN = function(x){
#   crossval_scores[rownames(crossval_scores)==bonica_ses$amicus[x],x]
# })
# 
# crossval_vec = unlist(crossval_vec_list)
# 
